{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Vocabulary: 'S', 'E' and 'P' are control symbols (start marker, end marker, padding), followed by a-z.\n",
    "char_sum = list('SEPabcdefghijklmnopqrstuvwxyz')\n",
    "# Lookup table from each character to its integer index in char_sum.\n",
    "dic_char = dict(zip(char_sum, range(len(char_sum))))\n",
    "\n",
    "# Antonym pairs [word, opposite]; the seq2seq model is trained to output the opposite.\n",
    "data = [\n",
    "    ['man', 'woman'],\n",
    "    ['small', 'large'],\n",
    "    ['in', 'out'],\n",
    "    ['young', 'old'],\n",
    "    ['black', 'white'],\n",
    "    ['far', 'near'],\n",
    "    ['short', 'long'],\n",
    "    ['king', 'queen'],\n",
    "    ['up', 'down'],\n",
    "]\n",
    "\n",
    "# Hyperparameters\n",
    "n_step = 5  # words are padded up to this length; no word in data is longer than 5\n",
    "num_class = len(char_sum)  # number of distinct characters\n",
    "n_hidden = 128  # hidden units in each RNN cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_batch(data):\n",
    "    \"\"\"Turn [word, antonym] pairs into encoder, decoder and target batches.\n",
    "\n",
    "    Both words of every pair are right-padded with 'P' up to n_step characters\n",
    "    (longer words are left as they are). The pairs passed in are not modified.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of two-item sequences [source_word, target_word].\n",
    "\n",
    "    Returns:\n",
    "        A tuple (input_batch, output_batch, target_batch) of lists, one entry per pair:\n",
    "        input_batch holds the one-hot rows of the padded source word,\n",
    "        output_batch holds the one-hot rows of 'S' + padded target word,\n",
    "        target_batch holds the character indices of padded target word + 'E'.\n",
    "    \"\"\"\n",
    "    one_hot = np.eye(num_class)  # row i is the one-hot vector of class i; built once, reused below\n",
    "    input_batch, output_batch, target_batch = [], [], []\n",
    "    for seq in data:\n",
    "        # Pad into local variables rather than writing back into seq, so the caller's data stays untouched.\n",
    "        source = seq[0].ljust(n_step, 'P')\n",
    "        answer = seq[1].ljust(n_step, 'P')\n",
    "        input_ = [dic_char[c] for c in source]  # with this notebook's vocabulary: 'manPP' -> [15, 3, 16, 2, 2]\n",
    "        output_ = [dic_char[c] for c in 'S' + answer]  # decoder input carries a leading 'S'\n",
    "        target_ = [dic_char[c] for c in answer + 'E']  # with this notebook's vocabulary: 'woman' + 'E' -> [25, 17, 15, 3, 16, 1]\n",
    "        # Indexing the identity matrix with a list of indices yields one one-hot row per character.\n",
    "        input_batch.append(one_hot[input_])\n",
    "        output_batch.append(one_hot[output_])\n",
    "        target_batch.append(target_)\n",
    "    return input_batch, output_batch, target_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1000:\n",
      "cost:0.001274\n",
      "Epoch 2000:\n",
      "cost:0.000304\n",
      "Epoch 3000:\n",
      "cost:0.000131\n",
      "Epoch 4000:\n",
      "cost:0.000125\n",
      "Epoch 5000:\n",
      "cost:0.000094\n",
      "test\n",
      "man -> woman\n",
      "mans -> woman\n",
      "king -> queen\n",
      "black -> white\n",
      "up -> downP\n"
     ]
    }
   ],
   "source": [
    "# Model definition\n",
    "encoder_input = tf.placeholder(tf.float32,[None,None,num_class])#(batch_size,max_len,num_class)\n",
    "decoder_input = tf.placeholder(tf.float32,[None,None,num_class])#(batch_size,max_len+1,num_class) max_len+1 because of the leading 'S'\n",
    "target = tf.placeholder(tf.int32,[None,None])#(batch_size,max_len+1) max_len+1 because of the trailing 'E'\n",
    "# Dropout keep probability. It defaults to 0.5 so a training run needs no extra feed;\n",
    "# inference must feed 1.0, otherwise dropout stays active and predictions become random.\n",
    "keep_prob = tf.placeholder_with_default(0.5,shape=[])\n",
    "\n",
    "# Encoder\n",
    "with tf.variable_scope('encoder'):\n",
    "    encoder = tf.contrib.rnn.BasicRNNCell(n_hidden)# RNN cell with n_hidden hidden units\n",
    "    encoder = tf.contrib.rnn.DropoutWrapper(encoder,output_keep_prob=keep_prob)# dropout on the cell outputs\n",
    "    # Only the final state is kept; it initialises the decoder below.\n",
    "    _,encoder_output = tf.nn.dynamic_rnn(encoder,encoder_input,dtype=tf.float32)\n",
    "\n",
    "# Decoder\n",
    "with tf.variable_scope('decoder'):\n",
    "    decoder = tf.contrib.rnn.BasicRNNCell(n_hidden)# RNN cell with n_hidden hidden units\n",
    "    decoder = tf.contrib.rnn.DropoutWrapper(decoder,output_keep_prob=keep_prob)# dropout on the cell outputs\n",
    "    decoder_output,_ = tf.nn.dynamic_rnn(decoder,decoder_input,initial_state=encoder_output,dtype=tf.float32)\n",
    "\n",
    "# Dense layer mapping every decoder output to num_class logits\n",
    "model = tf.layers.dense(decoder_output,num_class,activation=None)\n",
    "\n",
    "# Loss and optimisation step\n",
    "cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model,labels=target))\n",
    "optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)\n",
    "\n",
    "# Most likely class per time step. Built once here so that calling test() repeatedly\n",
    "# does not keep adding new ops to the graph.\n",
    "prediction = tf.argmax(model,2) # model : [batch_size, max_len+1, n_class]\n",
    "\n",
    "def test(word):\n",
    "    \"\"\"Return the string the model predicts for a single word.\n",
    "\n",
    "    The decoder is fed 'S' followed by padding only. Dropout is switched off by\n",
    "    feeding keep_prob = 1.0. Relies on the module-level session `sess`, which\n",
    "    must exist before the first call.\n",
    "\n",
    "    Args:\n",
    "        word: the input word as a string of characters from char_sum.\n",
    "\n",
    "    Returns:\n",
    "        The predicted characters up to the first 'E' (or all of them if no 'E'\n",
    "        was predicted), with padding characters 'P' removed.\n",
    "    \"\"\"\n",
    "    seq_data = [word,'P'*len(word)]\n",
    "    input_batch,output_batch,_ = get_batch([seq_data])\n",
    "    res = sess.run(prediction,feed_dict={encoder_input:input_batch,decoder_input:output_batch,keep_prob:1.0})\n",
    "    decoded = [char_sum[i] for i in res[0]]\n",
    "    # Cut at the end marker; if the model never emitted 'E', keep the whole sequence instead of raising.\n",
    "    end = decoded.index('E') if 'E' in decoded else len(decoded)\n",
    "    # 'P' is only a padding symbol, so it is dropped from the answer (e.g. 'downP' -> 'down').\n",
    "    translated = ''.join(decoded[:end]).replace('P','')\n",
    "    return translated\n",
    "\n",
    "\n",
    "# Open a session and initialise every variable in the graph\n",
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "\n",
    "# The whole data set is used as one batch\n",
    "input_batch,output_batch,target_batch = get_batch(data)\n",
    "train_feed = {encoder_input:input_batch,decoder_input:output_batch,target:target_batch}\n",
    "\n",
    "for step in range(1,5001):# 5000 training iterations\n",
    "    _,cost_ = sess.run([optimizer,cost],feed_dict=train_feed)\n",
    "    if step%1000 == 0:# report the loss every 1000 iterations\n",
    "        print('Epoch %04d:' % step)\n",
    "        print('cost:%.6f' % cost_)\n",
    "\n",
    "# Try the trained model on a few words\n",
    "print('test')\n",
    "for label,word in (('man ->','man'),('mans ->','mans'),('king ->','king'),('black ->','black')):\n",
    "    print(label, test(word))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
